library(ggplot2)
library(gridExtra)
library(ggthemes)
library(pheatmap)
library(reshape2)
library(ggpubr)
library(viridis)
# Restore the saved objects this script works on. verbose = TRUE makes load()
# report each restored object; the names it reported are recorded below.
load("results/TM_dropseq_tSNE_withFACS_ggplot_obj.RData", verbose = TRUE)
## Loading objects:
##   df_tsne_dropseq
##   tm_cellTypes_colors
##   trainFACS
#' Categorise each predicted label against its reference label.
#'
#' For a reference label that is present in the training labels the result is
#'   "incorrectly unassigned" - prediction is "unassigned"/"Unassigned";
#'   "intermediate"           - prediction is "intermediate", or is a different
#'                              string that contains the reference label;
#'   "correct"                - prediction equals the reference label;
#'   "misclassified"          - anything else.
#' For a reference label that is absent from the training labels the result is
#'   "correctly unassigned"   - prediction is "unassigned"/"Unassigned";
#'   "error assigned"         - anything else.
#'
#' @param cellTypes_pred Predicted labels (character or factor).
#' @param cellTypes_test Reference labels, same length as `cellTypes_pred`.
#' @param cellTypes_train Labels available in the training data.
#' @return Character vector with one category per element of
#'   `cellTypes_pred`; `character(0)` for empty input.
ClassifyError <- function(cellTypes_pred, cellTypes_test, cellTypes_train) {
  if (length(cellTypes_pred) != length(cellTypes_test)) {
    stop("wrong input")
  }

  # Compare as plain strings: `==` between two factors with different level
  # sets is an error, and predictions/references rarely share levels.
  pred <- as.character(cellTypes_pred)
  truth <- as.character(cellTypes_test)
  train_ref <- unique(as.character(cellTypes_train))
  unassigned_labels <- c("unassigned", "Unassigned")

  classify_one <- function(i) {
    if (truth[i] %in% train_ref) {
      if (pred[i] %in% unassigned_labels) {
        "incorrectly unassigned"
      } else if (pred[i] == "intermediate") {
        "intermediate"
      } else if (truth[i] == pred[i]) {
        "correct"
      } else if (grepl(truth[i], pred[i], fixed = TRUE)) {
        # fixed = TRUE: the reference label is matched as a literal substring,
        # so names containing "(", "+", "." etc. are not read as a regex.
        "intermediate"
      } else {
        "misclassified"
      }
    } else if (pred[i] %in% unassigned_labels) {
      "correctly unassigned"
    } else {
      "error assigned"
    }
  }

  # seq_along() + vapply() keep the result a character vector even when the
  # input is empty (1:length(x) would iterate over c(1, 0)).
  vapply(seq_along(pred), classify_one, character(1))
}




# Labels predicted for the Drop-seq cells, taken from the stored test result.
facs_toDropseq_labels <- trainFACS$testRes$tm$pearson_WKNN_limma$predRes

# Number of "_"-separated pieces in each predicted label.
len <- lengths(strsplit(facs_toDropseq_labels, "_"))

# Labels made of several pieces are collapsed to "intermediate"; those with
# more than ten pieces are then overwritten with "unassigned".
facs_toDropseq_labels[len > 1] <- "intermediate"
facs_toDropseq_labels[len > 10] <- "unassigned"

library(pheatmap)

# The training label set is the names of element 9 of cutree_list
# (NOTE(review): presumably the finest tree level - confirm).
classifyres <- ClassifyError(
  cellTypes_pred = facs_toDropseq_labels,
  cellTypes_test = df_tsne_dropseq$cellTypes,
  cellTypes_train = names(trainFACS$trainRes$cutree_list[[9]])
)

# Fix the category order used for tables and plot legends.
errClass <- c(
  "correct",
  "correctly unassigned",
  "intermediate",
  "incorrectly unassigned",
  "error assigned",
  "misclassified"
)
classifyres <- factor(classifyres, levels = errClass)

# Fraction of cells falling into each category.
classifyResFreq <- table(classifyres) / length(classifyres)

# Store the predictions as a factor whose first two levels are
# "intermediate" and "unassigned" (in that order).
df_tsne_dropseq$scClassify <- facs_toDropseq_labels
df_tsne_dropseq$scClassify <- relevel(
  relevel(as.factor(df_tsne_dropseq$scClassify), "unassigned"),
  "intermediate"
)



# tSNE of the Drop-seq cells coloured by their original cell-type labels.
g1 <- ggplot(
  data = df_tsne_dropseq,
  mapping = aes(x = tSNE1, y = tSNE2, color = cellTypes)
) +
  ggrastr::geom_point_rast(size = 0.5, alpha = 0.5) +
  theme_bw() +
  theme(
    aspect.ratio = 1,
    legend.position = "bottom",
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    panel.border = element_rect(fill = NA, colour = "black", size = 1.2)
  ) +
  scale_color_manual(
    values = tm_cellTypes_colors[levels(df_tsne_dropseq$cellTypes)]
  ) +
  labs(title = "Tabula Muris Dropseq (Original label)")

# Same embedding coloured by the labels stored in the scClassify column.
g2 <- ggplot(
  data = df_tsne_dropseq,
  mapping = aes(x = tSNE1, y = tSNE2, color = scClassify)
) +
  ggrastr::geom_point_rast(size = 0.5, alpha = 0.5) +
  theme_bw() +
  theme(
    aspect.ratio = 1,
    legend.position = "right",
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    panel.border = element_rect(fill = NA, colour = "black", size = 1.2)
  ) +
  scale_color_manual(values = tm_cellTypes_colors) +
  labs(title = "Tabula Muris Dropseq (scClassify Labelled by FACS)")

# Figure 5A
# Each panel is shown without its legend; ggsave() is called without a plot
# argument, so it writes the most recent plot.
g1 + theme(legend.position = "none")
ggsave(filename = "figures/Figure5A_TM_original.pdf", height = 9, width = 8)

g2 + theme(legend.position = "none")
ggsave(filename = "figures/Figure5A_TM_scClassify.pdf", height = 9, width = 8)




# Cross-tabulate the classification category by tissue and by original
# cell type.
tab_byTissue <- table(df_tsne_dropseq$tissue, classifyres)
tab_byCellTypes <- table(df_tsne_dropseq$cellTypes, classifyres)

# One row per cell: its classification category and its tissue.
res_composition <- data.frame(
  classifyres = classifyres,
  tissue = df_tsne_dropseq$tissue
)

library(viridis)

# Order the tissue levels by increasing number of cells.
res_composition$tissue <- factor(
  as.character(res_composition$tissue),
  levels = local({
    tissue_counts <- table(df_tsne_dropseq$tissue)
    names(tissue_counts)[order(tissue_counts)]
  })
)

# Horizontal stacked bars: cells per tissue, split by classification category.
ggplot(data = res_composition, mapping = aes(x = tissue, fill = classifyres)) +
  geom_bar() +
  scale_fill_manual(values = plasma(11)[c(1, 2, 4, 6, 8, 10)]) +
  coord_flip() +
  theme_bw() +
  theme(
    aspect.ratio = 0.5,
    legend.position = "right",
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    panel.border = element_rect(fill = NA, colour = "black", size = 1.2)
  )

ggsave(
  filename = "figures/Figure5B_TM_FACStoDrop_composition_barplot.pdf",
  height = 6,
  width = 10
)




# Contingency table: rows are the scClassify labels, columns the original
# cell types (unused factor levels dropped).
tab <- table(
  df_tsne_dropseq$scClassify,
  droplevels(df_tsne_dropseq$cellTypes)
)

# Labels present on both sides, then the labels found on one side only.
# "intermediate" and "unassigned" are left out of otherCell_train because
# they are appended explicitly as the last heatmap columns.
common_cellTypes <- intersect(unique(df_tsne_dropseq$cellTypes), unique(df_tsne_dropseq$scClassify))
otherCell_train <- local({
  predicted <- as.character(unique(df_tsne_dropseq$scClassify))
  excluded <- c(as.character(common_cellTypes), "intermediate", "unassigned")
  predicted[!predicted %in% excluded]
})
otherCell_test <- local({
  original <- as.character(unique(df_tsne_dropseq$cellTypes))
  original[!original %in% common_cellTypes]
})

library(RColorBrewer)

library(superheat)

# For every original cell type, the percentage of its cells given each
# predicted label (each row sums to 100).
toPlot_mat <- t(apply(tab, 2, function(counts) counts / sum(counts))) * 100

#pdf("figures/Figure5C_TM_FACS_drop_heatmap_order.pdf", width = 8, height = 8)
superheat(
  toPlot_mat[
    c(common_cellTypes, otherCell_test),
    c(common_cellTypes, otherCell_train, "intermediate", "unassigned")
  ],
  # Side bar plot: number of cells per original cell type, in row order.
  yr = as.numeric(
    table(droplevels(df_tsne_dropseq$cellTypes))[c(common_cellTypes, otherCell_test)]
  ),
  yr.axis.name = "Count (Original)",
  yr.plot.type = "bar",
  # Label appearance.
  left.label.col = "white",
  left.label.text.size = 2,
  left.label.text.alignment = "right",
  bottom.label.col = "white",
  bottom.label.text.size = 2,
  bottom.label.text.angle = 90,
  bottom.label.text.alignment = "right",
  # Colour scale.
  heat.pal = colorRampPalette(brewer.pal(n = 9, name = "Reds"))(100),
  heat.pal.values = seq(0, 1, 0.01),
  heat.col.scheme = "red",
  legend.text.size = 10,
  grid.hline = FALSE,
  grid.vline = FALSE,
  # pretty.order.cols = TRUE,
  order.rows = nrow(toPlot_mat):1
)

#dev.off()




facs_trainRes <- trainFACS

# One element per level of the tree stored in the training result.
level_list <- facs_trainRes$trainRes$cutree_list

# Edge labels of the form "<i-1><parent>_<i><child>", linking every entry of
# level i to the matching entry of level i - 1. Level 1 has no parent, so its
# slot stays NULL.
E_list <- lapply(seq_along(level_list), function(i) {
  if (i == 1) {
    return(NULL)
  }
  #   E_list[[i]] <- paste("0", level_list[[i]], sep = "_")
  paste(
    paste0(i - 1, level_list[[i - 1]]),
    paste0(i, level_list[[i]]),
    sep = "_"
  )
})

# Vertex labels "<i><value>" for each level, carrying the names of the
# original level vector.
V_list <- lapply(seq_along(level_list), function(i) {
  setNames(paste0(i, level_list[[i]]), names(level_list[[i]]))
})


# E_list[[length(E_list) + 1]] <-  paste(level_list[[length(level_list)]], rownames(data), sep = "_")

library(igraph)

# Build the graph from the distinct edges: each "a_b" label is split into
# its two vertex names.
g <- local({
  edge_labels <- unique(unlist(lapply(E_list, sort)))
  igraph::make_graph(unlist(strsplit(edge_labels, "_")))
})

# Colour groups are read from the level five steps above the last one and
# indexed by the vertex labels of the last level.
group_level <- length(level_list) - 5
group_col <- setNames(V_list[[group_level]], V_list[[length(level_list)]])

# Group attribute: NA everywhere except on last-level vertices.
V_group <- setNames(rep(NA, length(igraph::V(g))), names(V(g)))
V_group[V_list[[length(V_list)]]] <- group_col[V_list[[length(V_list)]]]
# igraph::V(g)$group[igraph::V(g)$name %in% V_list[[length(V_list)]]] <- as.factor(group_col)
V(g)$group <- V_group

# Display names: last-level vertices get the names carried by V_list, all
# other vertices become NA. This must run after V_group is built, because it
# overwrites the vertex names used for the lookups above.
V_name <- setNames(rep(NA, length(igraph::V(g))), names(V(g)))
V_name[V_list[[length(V_list)]]] <- names(V_list[[length(V_list)]])
# igraph::V(g)$group[igraph::V(g)$name %in% V_list[[length(V_list)]]] <- as.factor(group_col)
V(g)$name <- V_name

# Flipped dendrogram layout with labelled nodes and leaf points coloured by
# group.
treePlot <- ggraph::ggraph(g, layout = 'dendrogram') +
  ggraph::geom_edge_diagonal(colour = "black") +
  ggraph::geom_node_text(
    aes(vjust = 1, hjust = -0.2, label = name),
    size = 3,
    alpha = 1
  ) +
  ggraph::geom_node_point(
    aes(filter = leaf, alpha = 1, size = 3, colour = as.factor(V(g)$group))
  ) +
  scale_colour_manual(
    values = unique(c(
      tableau_color_pal("Classic 20")(20),
      tableau_color_pal("Tableau 20")(20)
    ))
  ) +
  ggplot2::coord_flip() +
  ggplot2::theme_void() +
  ggplot2::theme(legend.position = "none")

treePlot

ggsave(
  filename = "figures/AppendixFigS6_TM_facs_HOPACH_tree.pdf",
  height = 15,
  width = 10
)
# Restore the saved lung objects (NOTE(review): df_toPlot_tm_facs and
# exprsMat_tm_facs used below presumably come from this file - confirm).
load("results/lung_facs_tSNE_withCohenOnly_ggplot_obj.RData")

# tSNE of the lung FACS cells coloured by the scClassify_byCohen labels.
ggplot(
  data = df_toPlot_tm_facs,
  mapping = aes(x = tSNE1, y = tSNE2, color = scClassify_byCohen)
) +
  geom_point() +
  theme_bw() +
  theme(
    aspect.ratio = 1,
    legend.position = "bottom",
    plot.title = element_text(hjust = 0.5),
    text = element_text(face = "bold", size = 14),
    panel.border = element_rect(fill = NA, colour = "black", size = 1.2)
  ) +
  scale_color_manual(
    values = tm_cellTypes_colors[levels(df_toPlot_tm_facs$scClassify_byCohen)]
  ) +
  labs(title = "Tabula Muris lung (scClassify Labelled by Cohen et al.)")

ggsave(filename = "figures/lung_facs_tSNE_byCohen.pdf", height = 12, width = 14)


# Draw the lung FACS tSNE coloured by the expression of one marker gene,
# restricted to a rectangular window of the embedding.
#
# The eight marker panels below were previously eight copy-pasted ggplot
# chains differing only in gene, title and axis window; they now share this
# helper so that the styling cannot drift between panels.
#
# Arguments:
#   marker_gene  Row name of exprsMat_tm_facs whose values colour the points.
#   panel_title  Plot title.
#   x_window     Length-2 numeric: tSNE1 limits passed to xlim().
#   y_window     Length-2 numeric: tSNE2 limits passed to ylim().
# Returns the ggplot object (auto-printed when called at top level).
# Reads df_toPlot_tm_facs and exprsMat_tm_facs from the calling environment.
plot_marker_tsne <- function(marker_gene, panel_title, x_window, y_window) {
  ggplot(
    df_toPlot_tm_facs,
    aes(x = tSNE1, y = tSNE2, color = exprsMat_tm_facs[marker_gene, ])
  ) +
    geom_point(alpha = 1, size = 2) +
    scale_color_viridis_c(guide = guide_legend(title = marker_gene)) +
    theme_bw() +
    xlim(x_window) + ylim(y_window) +
    theme(
      aspect.ratio = 1, legend.position = "none",
      panel.grid.major = element_blank(), panel.grid.minor = element_blank(),
      text = element_text(size = 14, face = "bold"),
      panel.border = element_rect(colour = "black", fill = NA, size = 1.2),
      plot.title = element_text(hjust = 0.5)
    ) +
    labs(title = panel_title)
}

# Window tSNE1 in [6, 24], tSNE2 in [-20, 5].
plot_marker_tsne("COL1A2", "Marker of Fibroblast: COL1A2", c(6, 24), c(-20, 5))

plot_marker_tsne("MFAP4", "Marker of Matrix Fibroblast: MFAP4", c(6, 24), c(-20, 5))

plot_marker_tsne("ASPN", "Marker of Smooth Muscle Fibroblast: ASPN", c(6, 24), c(-20, 5))

plot_marker_tsne("GUCY1A3", "Marker of Pericytes: GUCY1A3", c(6, 24), c(-20, 5))

# Window tSNE1 in [7, 12], tSNE2 in [8, 13].
plot_marker_tsne("RETNLG", "Marker of Neutrophil: RETNLG", c(7, 12), c(8, 13))

plot_marker_tsne("CCL3", "Marker of Basophil: CCL3", c(7, 12), c(8, 13))

plot_marker_tsne("CCL4", "Marker of Basophil: CCL4", c(7, 12), c(8, 13))

plot_marker_tsne("IFITM1", "Marker of Basophil: IFITM1", c(7, 12), c(8, 13))